Introduction¶
TP53, known as the Guardian of the Genome, is the most commonly mutated gene in human cancers, so understanding its regulation is critical. In the paper "A machine learning and directed network optimization approach to uncover TP53 regulatory patterns", two strategies were used to explore this: machine learning to predict TP53 mutation status from transcriptomic data, and directed regulatory networks to analyze the impact of these mutations on the expression of TP53 target genes.
In this notebook we provide our own solution to the first of the two problems: predicting the mutation type of a given cell, as well as whether the mutation compromises the functionality of the cell itself.
We begin by inspecting simple but informative statistics about the dataset, and by visualizing the distributions of some cells and genes. Then, we resort to unsupervised machine learning methods, specifically clustering and dimensionality reduction, to gain further insights into the structure of the data. Next, we train various types of classifiers to accomplish the two tasks. These include classical machine learning methods like logistic regression and KNN, ensemble methods like random forests, and kernel methods like kernel SVM. To tackle overfitting and reduce noise, we perform feature selection through correlation analysis as well as feature extraction through PCA. Finally, we resort to more sophisticated approaches. First, a Graph Neural Network, which, however, suffers from the limited size of the training set. Second, to exploit powerful deep neural networks in our limited-data setting, we try a transfer learning approach, training a simple classifier on top of frozen representations from scGPT, a foundation model for single-cell sequencing analysis that was trained on tens of millions of datapoints.
Exploratory Data Analysis¶
Data comes from The Cancer Genome Atlas (TCGA) program and contains different cancer cells in which TP53 is mutated. The data includes the type of mutation (missense or others) and whether the cell was compromised.
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sklearn
import json
import os
import sys
import warnings
import torch
import scanpy as sc
import networkx as nx
import tqdm
import gseapy as gp
import torchtext
torchtext.disable_torchtext_deprecation_warning()
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df.head()
| Variant_Classification | ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | ... | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | is_true | mutation | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... | 376.831000 | 1358.86000 | 2471.580000 | 143602.00000 | 159.674000 | 63.136500 | 946.639000 | 626.477000 | 344.195000 | ... | 323.344000 | 75.356400 | 8558.040000 | 43.991900 | 1783.300000 | 5320.570000 | 1018.330000 | 821.181000 | True | Frame_Shift_Ins |
| 1 | A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... | 198.244448 | 5367.62179 | 2528.570328 | 77726.97678 | 19.656121 | 2.579692 | 2130.976296 | 732.991931 | 386.605718 | ... | 228.638412 | 322.247574 | 6446.509718 | 36.542642 | 3207.438557 | 3213.116903 | 1688.261865 | 1149.407697 | True | In_Frame_Del |
| 2 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 117.516000 | 1936.34000 | 14533.700000 | 185841.00000 | 95.490700 | 191.866000 | 766.578000 | 256.410000 | 239.611000 | ... | 230.672000 | 121.132000 | 12726.800000 | 74.270600 | 2496.910000 | 4005.300000 | 923.961000 | 391.689000 | True | Frame_Shift_Del |
| 3 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 60.747000 | 5667.60000 | 3560.420000 | 107645.00000 | 86.834700 | 1047.620000 | 698.413000 | 186.741000 | 262.372000 | ... | 638.609000 | 343.604000 | 8024.280000 | 78.431400 | 3746.030000 | 2692.810000 | 1168.070000 | 670.402000 | True | Frame_Shift_Del |
| 4 | A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... | 327.477000 | 1096.61000 | 3430.480000 | 64166.60000 | 51.837300 | 9.491300 | 706.010000 | 1617.540000 | 821.366000 | ... | 806.811000 | 124.118000 | 1350.690000 | 237.649000 | 1885.860000 | 2283.400000 | 1967.630000 | 480.043000 | True | Frame_Shift_Del |
5 rows × 554 columns
Basic Exploration¶
import sklearn
from sklearn.decomposition import PCA
from copy import deepcopy
df_full = deepcopy(df)
df = df.drop(columns=['is_true', 'mutation', 'Variant_Classification'])
df.head()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 376.831000 | 1358.86000 | 2471.580000 | 143602.00000 | 159.674000 | 63.136500 | 946.639000 | 626.477000 | 344.195000 | 435.438000 | ... | 1820.37000 | 264.358000 | 323.344000 | 75.356400 | 8558.040000 | 43.991900 | 1783.300000 | 5320.570000 | 1018.330000 | 821.181000 |
| 1 | 198.244448 | 5367.62179 | 2528.570328 | 77726.97678 | 19.656121 | 2.579692 | 2130.976296 | 732.991931 | 386.605718 | 1185.830576 | ... | 527.55836 | 16.858995 | 228.638412 | 322.247574 | 6446.509718 | 36.542642 | 3207.438557 | 3213.116903 | 1688.261865 | 1149.407697 |
| 2 | 117.516000 | 1936.34000 | 14533.700000 | 185841.00000 | 95.490700 | 191.866000 | 766.578000 | 256.410000 | 239.611000 | 1976.130000 | ... | 1439.43000 | 164.456000 | 230.672000 | 121.132000 | 12726.800000 | 74.270600 | 2496.910000 | 4005.300000 | 923.961000 | 391.689000 |
| 3 | 60.747000 | 5667.60000 | 3560.420000 | 107645.00000 | 86.834700 | 1047.620000 | 698.413000 | 186.741000 | 262.372000 | 738.562000 | ... | 2136.32000 | 414.566000 | 638.609000 | 343.604000 | 8024.280000 | 78.431400 | 3746.030000 | 2692.810000 | 1168.070000 | 670.402000 |
| 4 | 327.477000 | 1096.61000 | 3430.480000 | 64166.60000 | 51.837300 | 9.491300 | 706.010000 | 1617.540000 | 821.366000 | 1041.860000 | ... | 1168.53000 | 183.986000 | 806.811000 | 124.118000 | 1350.690000 | 237.649000 | 1885.860000 | 2283.400000 | 1967.630000 | 480.043000 |
5 rows × 551 columns
df.describe()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | ... | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 |
| mean | 185.589301 | 3954.889567 | 7200.974864 | 114439.524334 | 207.436478 | 147.286450 | 825.070216 | 592.848118 | 442.641521 | 1094.332390 | ... | 1777.897799 | 534.550961 | 412.665719 | 255.096842 | 7621.347871 | 126.568887 | 3142.460706 | 2919.691552 | 1307.240249 | 661.287621 |
| std | 148.231968 | 3147.431047 | 17236.982170 | 48604.774763 | 241.214042 | 262.220737 | 560.475285 | 350.756481 | 239.295632 | 749.287419 | ... | 1141.164258 | 740.663305 | 239.702838 | 231.295388 | 4516.262873 | 109.098272 | 2018.346330 | 1489.464482 | 798.890198 | 325.307174 |
| min | 8.051500 | 29.764800 | 30.048173 | 24218.900000 | 0.000000 | -0.210767 | 20.852200 | 58.117025 | 8.992800 | 57.971000 | ... | 42.885000 | -0.319802 | 33.801400 | 0.000000 | 278.652000 | 0.399949 | 96.363800 | 339.315000 | 64.507200 | 81.250000 |
| 25% | 103.270500 | 1815.710000 | 1221.666028 | 79320.416245 | 39.071550 | 15.592600 | 519.444000 | 364.062373 | 276.480676 | 553.878500 | ... | 1040.665390 | 83.612450 | 259.671500 | 117.194500 | 4516.985000 | 56.959424 | 1839.472102 | 1958.070000 | 791.811500 | 459.930000 |
| 50% | 151.606000 | 3167.505717 | 2886.380000 | 105861.297800 | 123.480903 | 48.509415 | 697.903000 | 523.519000 | 400.155965 | 906.412000 | ... | 1546.240000 | 263.096000 | 365.717000 | 189.090000 | 6659.250000 | 99.149100 | 2705.100000 | 2640.281167 | 1107.065208 | 592.370060 |
| 75% | 224.627000 | 5211.564607 | 6516.566190 | 139285.500000 | 298.676000 | 154.258000 | 965.818000 | 734.935000 | 559.460000 | 1452.640926 | ... | 2216.760000 | 662.011282 | 504.985500 | 304.513576 | 9518.812363 | 163.171022 | 3862.868994 | 3476.665000 | 1606.385000 | 773.011315 |
| max | 3918.930000 | 34507.870810 | 252607.000000 | 478521.000000 | 3115.940000 | 4181.060000 | 11942.500000 | 6406.190000 | 2910.670000 | 8566.567356 | ... | 11829.000000 | 8530.350000 | 3955.320000 | 4229.350000 | 37363.300000 | 1614.630000 | 26002.200000 | 22667.957440 | 12483.000000 | 4091.275370 |
8 rows × 551 columns
Cumulative distribution function of a few genes:
L = 2
fig, axs = plt.subplots(L, L)
for i in range(L):
for j in range(L):
gene = df.columns[i*L + j]
gene_data = df[gene].sort_values()
gene_data = gene_data.reset_index(drop=True)
ecdf = sm.distributions.ECDF(gene_data)
axs[i, j].step(ecdf.x, ecdf.y)
plt.show()
We notice that the gene reads span multiple orders of magnitude. There are some outliers, as can be seen from the quantiles and from the curves above.
We visualize the distribution of a few cells and genes to get a feeling for the data, using violin plots.
L = 2
fig, axs = plt.subplots(L, L, figsize=(10, 10))
for i in range(L):
for j in range(L):
sns.violinplot(x=df.iloc[i * L + j, :], ax=axs[i, j])
axs[i, j].set_title('Gene Expression profile of cell {}'.format(i * L + j))
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
L = 2
fig, axs = plt.subplots(L, L, figsize=(10, 10))
for i in range(L):
for j in range(L):
sns.violinplot(x=df.iloc[:, i * L + j], ax=axs[i, j])
axs[i, j].set_title('Expression of gene {} across cells'.format(i * L + j))
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
We check the sparsity level of the gene expression profiles. With some single cell sequencing datasets, it happens that the examples are extremely sparse. This is important to take into account, both for computational efficiency and to correctly interpret statistics. In our case, however, it seems that the sparsity level is small:
print("sparsity level:", (df == 0).sum().sum() / df.size)
sparsity level: 0.02183073369763143
As noticed above, the distributions cover a very long range, spanning multiple orders of magnitude. This might be an indication that the data is more naturally handled in log scale. We therefore log-transform the data, adding 1 first so that zero counts map to zero.
df = df + 1
df = df.apply(np.log)
df.head()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 5.934447 | 7.215137 | 7.813017 | 11.874808 | 5.079377 | 4.161014 | 6.853974 | 6.441707 | 5.844109 | 6.078646 | ... | 7.507344 | 5.581080 | 5.781805 | 4.335412 | 9.054743 | 3.806482 | 7.486781 | 8.579524 | 6.926901 | 6.711961 |
| 1 | 5.294532 | 8.588327 | 7.835805 | 11.260971 | 3.028012 | 1.275277 | 7.664805 | 6.598498 | 5.959989 | 7.079042 | ... | 6.270153 | 2.882507 | 5.436506 | 5.778419 | 8.771449 | 3.625477 | 8.073540 | 8.075308 | 7.432047 | 7.047872 |
| 2 | 4.775048 | 7.569071 | 9.584294 | 12.132652 | 4.569447 | 5.261996 | 6.643240 | 5.550670 | 5.483182 | 7.589402 | ... | 7.272697 | 5.108705 | 5.445323 | 4.805102 | 9.451544 | 4.321090 | 7.823210 | 8.295623 | 6.829752 | 5.973018 |
| 3 | 4.123045 | 8.642697 | 8.177915 | 11.586603 | 4.475457 | 6.955230 | 6.550241 | 5.235063 | 5.573567 | 6.606058 | ... | 7.667308 | 6.029641 | 6.460857 | 5.842396 | 8.990352 | 4.374894 | 8.228719 | 7.898712 | 7.063964 | 6.509368 |
| 4 | 5.794467 | 7.000890 | 8.140747 | 11.069254 | 3.967217 | 2.350546 | 6.561045 | 7.389280 | 6.712186 | 6.949722 | ... | 7.064357 | 5.220280 | 6.694328 | 4.829257 | 7.209111 | 5.474994 | 7.542669 | 7.733859 | 7.585093 | 6.175957 |
5 rows × 551 columns
df.describe()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | ... | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 | 4211.000000 |
| mean | 5.027813 | 7.974901 | 7.969045 | 11.565688 | 4.650616 | 3.932039 | 6.578123 | 6.240796 | 5.955885 | 6.785079 | ... | 7.302806 | 5.368857 | 5.888605 | 5.255410 | 8.778089 | 4.571372 | 7.876275 | 7.872595 | 7.024809 | 6.400346 |
| std | 0.626356 | 0.858968 | 1.283715 | 0.404949 | 1.313455 | 1.525478 | 0.506659 | 0.544390 | 0.548099 | 0.671384 | ... | 0.623959 | 1.578166 | 0.523837 | 0.762479 | 0.580441 | 0.760368 | 0.606623 | 0.459292 | 0.548695 | 0.428670 |
| min | 2.202930 | 3.426371 | 3.435540 | 10.094930 | 0.000000 | -0.236694 | 3.084302 | 4.079519 | 2.301865 | 4.077046 | ... | 3.781573 | -0.385372 | 3.549658 | 0.000000 | 5.633546 | 0.336436 | 4.578454 | 5.829872 | 4.182160 | 4.409763 |
| 25% | 4.646988 | 7.504782 | 7.108789 | 11.281263 | 3.690667 | 2.808957 | 6.254682 | 5.900068 | 5.625751 | 6.318749 | ... | 6.948576 | 4.438081 | 5.563261 | 4.772332 | 8.415821 | 4.059743 | 7.517777 | 7.580225 | 6.675585 | 6.133246 |
| 50% | 5.027859 | 8.061015 | 7.968105 | 11.569894 | 4.824152 | 3.902163 | 6.549512 | 6.262482 | 5.994350 | 6.810597 | ... | 7.344228 | 5.576313 | 5.904590 | 5.247498 | 8.803912 | 4.606660 | 7.903264 | 7.879019 | 7.010371 | 6.385818 |
| 75% | 5.418883 | 8.558827 | 8.782255 | 11.844288 | 5.702702 | 5.045088 | 6.874010 | 6.601142 | 6.328758 | 7.281827 | ... | 7.704253 | 6.496792 | 6.226508 | 5.721994 | 9.161130 | 5.100909 | 8.259424 | 8.154116 | 7.382364 | 6.651586 |
| max | 8.273829 | 10.448972 | 12.439594 | 13.078457 | 8.044607 | 8.338559 | 9.387942 | 8.765176 | 7.976482 | 9.055739 | ... | 9.378394 | 9.051503 | 8.283070 | 8.350040 | 10.528471 | 7.387480 | 10.165975 | 10.028752 | 9.432203 | 8.316856 |
8 rows × 551 columns
Now let's visualize the new data after applying the logarithm:
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
To prepare for unsupervised and supervised machine learning algorithms, we also normalize the gene expressions, gene by gene. In fact, as we see from the graphs above, different genes have different means and standard deviations, so algorithms (e.g., k-means or logistic regression) risk focusing on some genes more than others simply because they exhibit larger variation.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
array = scaler.fit_transform(df)
df = pd.DataFrame(array, columns=df.columns)
df.head()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.447646 | -0.884613 | -0.121559 | 0.763446 | 0.326477 | 0.150118 | 0.544515 | 0.369101 | -0.203957 | -1.052329 | ... | 0.327846 | 0.134490 | -0.203906 | -1.206731 | 0.476685 | -1.006067 | -0.642145 | 1.539354 | -0.178460 | 0.727020 |
| 1 | 0.425878 | 0.714228 | -0.103805 | -0.752572 | -1.235518 | -1.741801 | 2.145054 | 0.657148 | 0.007488 | 0.437898 | ... | -1.655197 | -1.575655 | -0.863156 | 0.686013 | -0.011440 | -1.244144 | 0.325224 | 0.441411 | 0.742282 | 1.510726 |
| 2 | -0.403597 | -0.472518 | 1.258411 | 1.400254 | -0.061806 | 0.871933 | 0.128538 | -1.267854 | -0.862544 | 1.198149 | ... | -0.048260 | -0.164864 | -0.846323 | -0.590653 | 1.160385 | -0.329199 | -0.087487 | 0.921154 | -0.355537 | -0.996989 |
| 3 | -1.444666 | 0.777533 | 0.162726 | 0.051656 | -0.133374 | 1.982035 | -0.055037 | -1.847666 | -0.697616 | -0.266677 | ... | 0.584245 | 0.418754 | 1.092553 | 0.769930 | 0.365736 | -0.258430 | 0.581062 | 0.056869 | 0.071368 | 0.254357 |
| 4 | 1.224136 | -1.134066 | 0.133770 | -1.226063 | -0.520368 | -1.036843 | -0.033712 | 2.109920 | 1.380025 | 0.245258 | ... | -0.382200 | -0.094157 | 1.538300 | -0.558970 | -2.703400 | 1.188542 | -0.550005 | -0.302103 | 1.021242 | -0.523517 |
5 rows × 551 columns
df.describe()
| ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZBTB38..ENSG000253461 | ZBTB7C..ENSG000201501 | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | ... | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 | 4.211000e+03 |
| mean | 1.554049e-15 | -4.758325e-16 | -3.678421e-16 | -2.682885e-16 | -5.239219e-16 | -3.610927e-16 | 1.906705e-16 | -5.264529e-16 | 2.902241e-16 | 1.149085e-15 | ... | 4.302740e-17 | 1.147397e-16 | 4.049638e-16 | -9.449155e-17 | 3.610927e-16 | -9.837246e-16 | 6.665029e-16 | 1.737126e-15 | -1.404718e-16 | 1.181144e-16 |
| std | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | ... | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 | 1.000119e+00 |
| min | -4.510562e+00 | -5.295972e+00 | -3.531972e+00 | -3.632389e+00 | -3.541172e+00 | -2.733064e+00 | -6.896626e+00 | -3.970557e+00 | -6.667505e+00 | -4.033989e+00 | ... | -5.644040e+00 | -3.646584e+00 | -4.465560e+00 | -6.893345e+00 | -5.418151e+00 | -5.570250e+00 | -5.437008e+00 | -4.448082e+00 | -5.181365e+00 | -4.644180e+00 |
| 25% | -6.080725e-01 | -5.473712e-01 | -6.702104e-01 | -7.024539e-01 | -7.309452e-01 | -7.363043e-01 | -6.384559e-01 | -6.259623e-01 | -6.023963e-01 | -6.946628e-01 | ... | -5.677806e-01 | -5.898535e-01 | -6.211531e-01 | -6.336376e-01 | -6.241986e-01 | -6.729505e-01 | -5.910434e-01 | -6.366437e-01 | -6.365389e-01 | -6.231641e-01 |
| 50% | 7.394250e-05 | 1.002653e-01 | -7.327976e-04 | 1.038946e-02 | 1.321378e-01 | -1.958719e-02 | -5.647691e-02 | 3.984025e-02 | 7.018777e-02 | 3.801168e-02 | ... | 6.639363e-02 | 1.314692e-01 | 3.051905e-02 | -1.037819e-02 | 4.449485e-02 | 4.641426e-02 | 4.449540e-02 | 1.398821e-02 | -2.631787e-02 | -3.389436e-02 |
| 75% | 6.244316e-01 | 6.798808e-01 | 6.335571e-01 | 6.880706e-01 | 8.011018e-01 | 7.297264e-01 | 5.840663e-01 | 6.620046e-01 | 6.803826e-01 | 7.399737e-01 | ... | 6.434628e-01 | 7.147972e-01 | 6.451297e-01 | 6.120030e-01 | 6.599933e-01 | 6.965039e-01 | 6.316853e-01 | 6.130188e-01 | 6.517230e-01 | 5.861629e-01 |
| max | 5.182997e+00 | 2.880625e+00 | 3.482923e+00 | 3.736147e+00 | 2.584326e+00 | 2.888960e+00 | 5.546441e+00 | 4.637628e+00 | 3.686992e+00 | 3.382461e+00 | ... | 3.326874e+00 | 2.333775e+00 | 4.571553e+00 | 4.059123e+00 | 3.015966e+00 | 3.704052e+00 | 3.774953e+00 | 4.695084e+00 | 4.388014e+00 | 4.471364e+00 |
8 rows × 551 columns
Data visualization after normalization:
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
Dimensionality Reduction¶
In order to visualize the data, and to check whether any interesting patterns are visible, we use two dimensionality reduction techniques.
First, we resort to an explainable and simple linear technique, PCA, to find a set of few orthogonal directions that explain most of the variance of the data. We can also inspect the principal components, which represent 'metagenes' along which our dataset exhibits large variance, and may turn out to be useful for classification as well.
Then, we employ a more powerful nonlinear dimensionality reduction technique, tSNE, which was designed specifically for data visualization, to get a more accurate depiction of the local structure of our data.
PCA¶
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca.fit(df)
array_pca = pca.transform(df)
df_pca = pd.DataFrame(array_pca, columns=['PC1', 'PC2'])
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', data=df_pca, hue=df_full['is_true'], alpha=0.5)
plt.title("PCA of the gene expression profiles")
plt.show()
array_reconstructed = pca.inverse_transform(array_pca)
mse = ((array - array_reconstructed) ** 2).mean()
print("MSE of the reconstruction:", mse)
MSE of the reconstruction: 0.75991853023862
explained_variance = pca.explained_variance_ratio_.sum()
print("Explained variance:", explained_variance)
Explained variance: 0.24008146976137446
With only two directions, only 24% of the variance is explained, and the reconstruction error is high (0.76, on features standardized to mean 0 and standard deviation 1). We increase the number of principal components and monitor how the explained variance grows.
pca = PCA(n_components=551)
pca.fit(df)
array_pca = pca.transform(df)
var = pca.explained_variance_ratio_[0:20]
labels = ["PC"+str(i+1) for i in range(20)]
plt.figure(figsize=(16,4))
plt.bar(labels, var)
plt.xlabel('Principal Components')
plt.ylabel('Proportion of Variance Explained')
plt.show()
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cum_var)
plt.hlines(0.72, 0, 551, colors='r', linestyles='dashed', label='0.72 explained variance')
plt.xlabel('Number of Principal Components')
plt.ylabel('Cumulative Proportion of Variance Explained')
plt.legend()
plt.show()
np.argmax(cum_var > 0.72)  # first index where cumulative explained variance exceeds 0.72
49
As is often the case, most of the variance is explained by relatively few components. There seems to be an elbow around 49 components. These plots will be useful in the classification part, where we use PCA to extract features from the data.
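To make the 'metagene' interpretation mentioned above concrete, we can inspect the loadings of a principal component: each loading is the weight a gene carries in that direction, and large-magnitude loadings identify the genes driving it. A minimal self-contained sketch (with synthetic data standing in for our standardized expression matrix):

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

# Synthetic stand-in for the standardized expression matrix `df` used above,
# so that the snippet runs on its own.
rng = np.random.default_rng(0)
genes = [f"GENE{i}" for i in range(20)]
df_demo = pd.DataFrame(rng.standard_normal((100, 20)), columns=genes)

pca = PCA(n_components=5).fit(df_demo)

# Loadings of the first principal component: the weight of each gene in the
# 'metagene'; a large |weight| means the gene drives that direction.
pc1 = pd.Series(pca.components_[0], index=df_demo.columns)
top = pc1.abs().sort_values(ascending=False).head(10)
print(top)
```

Applied to our fitted PCA, the same idea lists the genes that dominate each direction of maximal variance.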
tSNE¶
As the results of tSNE are sensitive to the choice of hyperparameters, especially the perplexity, we try a few configurations on a grid and pick the hyperparameters that yield the most stable results.
from sklearn.manifold import TSNE
L = 3
fig, axs = plt.subplots(L, L, figsize=(15, 15))
perplexities = [5, 10, 15, 20, 25, 30, 35, 40, 45]
for i in range(L):
for j in range(L):
perplexity = perplexities[i * L + j]
tsne = TSNE(n_components=2, perplexity=perplexity, max_iter=1000)
array_tsne = tsne.fit_transform(df)
df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, alpha=0.5, ax=axs[i, j])
axs[i, j].set_title('Perplexity = {}'.format(perplexity))
Results look crisper and more stable for perplexity values in the range 15-35; we go with 25. Notice that some clusters seem to emerge. We will investigate further later, using clustering techniques, to see whether this is an artifact of dimensionality reduction or whether clusters emerge naturally even in the higher-dimensional space. Now, let's see whether any structure emerges by coloring points according to their labels - first functional vs. non-functional, and then mutation type.
tsne = TSNE(n_components=2, perplexity=25, max_iter=1000)
array_tsne = tsne.fit_transform(df)
df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_full['is_true'], alpha=0.5)
plt.title("tSNE of the gene expression profiles")
plt.show()
#plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_full['mutation'], alpha=0.5, palette='Set1')
plt.title("tSNE of the gene expression profiles")
plt.show()
Clustering¶
gene_columns = [col for col in df.columns if col not in ['is_true', 'mutation', 'Variant_Classification']]
X = df[gene_columns].to_numpy()
X.shape
(4211, 551)
# for visualization later
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
tsne = TSNE(n_components=2, perplexity=25, max_iter=1000)
array_tsne = tsne.fit_transform(df[gene_columns])
df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])
We want to perform clustering on our data. Gene expression profiles live in a 551-dimensional space. Since clustering is entirely based on a notion of distance in the data space, it is important that the metric used be meaningful. However, because of the well-known curse of dimensionality, distances tend to lose meaning as the dimension of the ambient space increases (with counter-intuitive phenomena such as the concentration of volume near boundaries and the quasi-orthogonality of random vectors). For this reason, we carry out clustering on dimensionally reduced data, using PCA. For the number of components, we use the elbow highlighted above, since it strikes a good trade-off between information retention and dimensionality.
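The quasi-orthogonality and distance concentration mentioned above are easy to verify numerically: for random unit vectors, the typical cosine similarity shrinks toward zero and the relative spread of pairwise distances collapses as the dimension grows. A small illustrative sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

for d in (2, 50, 551):
    # 500 random unit vectors in dimension d
    X = rng.standard_normal((500, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    # cosine similarities and distances between 250 random pairs
    cos = (X[:250] * X[250:]).sum(axis=1)
    dist = np.linalg.norm(X[:250] - X[250:], axis=1)
    print(f"d={d:4d}  mean|cos|={np.abs(cos).mean():.3f}  "
          f"dist std/mean={dist.std() / dist.mean():.3f}")
```

At d = 551 (our ambient dimension), random pairs are nearly orthogonal and almost all pairwise distances look alike, which is exactly why we cluster in the reduced PCA space instead.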
pca = PCA(n_components=50)
pca.fit(X)
X_pca = pca.transform(X)
To choose the number of clusters, which must be provided as input to Kmeans, we use two indicators: Silhouette score and inertia.
from sklearn.cluster import KMeans
inertia = []
for k in range(1, 51):
kmeans = KMeans(n_clusters=k)
kmeans.fit(X_pca)
inertia.append(kmeans.inertia_)
plt.plot(range(1, 51), inertia)
[<matplotlib.lines.Line2D at 0x2a6348b50>]
# silhouette score
from sklearn.metrics import silhouette_score
silhouette = []
for k in range(2, 51):
kmeans = KMeans(n_clusters=k)
kmeans.fit(X_pca)
silhouette.append(silhouette_score(X, kmeans.labels_))
plt.plot(range(2, 51), silhouette)
[<matplotlib.lines.Line2D at 0x2a637f8d0>]
2 + np.argmax(silhouette)
3
The number of clusters maximizing the Silhouette score is k = 3. The inertia gives little indication, as no clear elbow is visible. We go with k = 3.
k = 3
kmeans = KMeans(n_clusters=k)
kmeans.fit(X_pca)
df_tsne['cluster'] = kmeans.labels_
The Silhouette score is low, so the quality of the clusters is not optimal. Since the value is near 0, we expect many poorly clustered points, i.e., points lying close to the boundary between two clusters.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_tsne['cluster'], alpha=0.5, palette='Set1')
plt.title("tSNE of the gene expression profiles")
plt.show()
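One way to quantify "poorly clustered points" is to look at per-point silhouette values with `silhouette_samples`: a point with a negative value sits closer, on average, to another cluster than to its own. A self-contained sketch on synthetic overlapping blobs (standing in for our PCA-reduced matrix):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_samples

# Synthetic, heavily overlapping clusters as a stand-in for `X_pca`.
X_demo, _ = make_blobs(n_samples=600, centers=6, cluster_std=5.0, random_state=0)

labels = KMeans(n_clusters=6, n_init=10, random_state=0).fit_predict(X_demo)
sil = silhouette_samples(X_demo, labels)

# Fraction of points closer to another cluster than to their own.
frac_negative = (sil < 0).mean()
print(f"mean silhouette: {sil.mean():.3f}, fraction negative: {frac_negative:.2%}")
```

Running the same diagnostic on `X_pca` with our fitted k-means would tell us how many of our profiles are genuinely ambiguous rather than merely near-boundary.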
Correlation for feature selection¶
To find the most promising genes for classification, we compute the correlation of the labels with each individual gene. We use the point-biserial correlation coefficient, since we are dealing with a continuous and a binary variable; it is a special case of the Pearson correlation coefficient.
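The per-gene computation can be sketched as follows, using `scipy.stats.pointbiserialr` (equivalently, Pearson correlation against the 0/1 label). Synthetic data stands in here for the gene columns and the binary `is_true` label:

```python
import numpy as np
import pandas as pd
from scipy.stats import pointbiserialr

# Hypothetical stand-in: a binary label and two genes, one of which is
# shifted by the label (i.e., genuinely correlated with it).
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=200)
expr = pd.DataFrame(
    {"GENE_A": rng.standard_normal(200) + 0.8 * y,  # correlated with label
     "GENE_B": rng.standard_normal(200)}            # independent of label
)

# Point-biserial correlation of each gene with the binary label.
corr = {g: pointbiserialr(y, expr[g]).correlation for g in expr.columns}
ranked = pd.Series(corr).abs().sort_values(ascending=False)
print(ranked)
```

Ranking genes by the absolute value of this coefficient gives a simple univariate feature-selection criterion for the classifiers below.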
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df.head()
| Variant_Classification | ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | ... | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | is_true | mutation | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... | 376.831000 | 1358.86000 | 2471.580000 | 143602.00000 | 159.674000 | 63.136500 | 946.639000 | 626.477000 | 344.195000 | ... | 323.344000 | 75.356400 | 8558.040000 | 43.991900 | 1783.300000 | 5320.570000 | 1018.330000 | 821.181000 | True | Frame_Shift_Ins |
| 1 | A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... | 198.244448 | 5367.62179 | 2528.570328 | 77726.97678 | 19.656121 | 2.579692 | 2130.976296 | 732.991931 | 386.605718 | ... | 228.638412 | 322.247574 | 6446.509718 | 36.542642 | 3207.438557 | 3213.116903 | 1688.261865 | 1149.407697 | True | In_Frame_Del |
| 2 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 117.516000 | 1936.34000 | 14533.700000 | 185841.00000 | 95.490700 | 191.866000 | 766.578000 | 256.410000 | 239.611000 | ... | 230.672000 | 121.132000 | 12726.800000 | 74.270600 | 2496.910000 | 4005.300000 | 923.961000 | 391.689000 | True | Frame_Shift_Del |
| 3 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 60.747000 | 5667.60000 | 3560.420000 | 107645.00000 | 86.834700 | 1047.620000 | 698.413000 | 186.741000 | 262.372000 | ... | 638.609000 | 343.604000 | 8024.280000 | 78.431400 | 3746.030000 | 2692.810000 | 1168.070000 | 670.402000 | True | Frame_Shift_Del |
| 4 | A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... | 327.477000 | 1096.61000 | 3430.480000 | 64166.60000 | 51.837300 | 9.491300 | 706.010000 | 1617.540000 | 821.366000 | ... | 806.811000 | 124.118000 | 1350.690000 | 237.649000 | 1885.860000 | 2283.400000 | 1967.630000 | 480.043000 | True | Frame_Shift_Del |
5 rows × 554 columns
def log_and_normalize(df: pd.DataFrame) -> pd.DataFrame:
# all columns but 'is_true', 'mutation', and 'Variant_Classification'
features = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
# log-transform and normalize the features
features = features.apply(lambda x: np.log(1 + x))
features = (features - features.mean()) / features.std()
# add back the non-numeric columns
features = pd.concat(
[features, df[["mutation", "Variant_Classification", "is_true"]]], axis=1
)
return features
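The transform above applies `log(1 + x)` followed by a per-gene z-score. A small self-contained check (illustrative toy values, not our data) that the lambda is equivalent to NumPy's numerically safer `np.log1p`:

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({"geneA": [0.0, 10.0, 100.0], "geneB": [5.0, 50.0, 500.0]})
manual = toy.apply(lambda x: np.log(1 + x))
stable = np.log1p(toy)  # same transform, more accurate for counts near zero
assert np.allclose(manual.to_numpy(), stable.to_numpy())
```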
df = log_and_normalize(df)
df.head()
| | ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | AHDC1..ENSG00027245 | ... | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | mutation | Variant_Classification | is_true |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.447474 | -0.884508 | -0.121544 | 0.763355 | 0.326438 | 0.150100 | 0.544450 | 0.369058 | -0.203933 | -1.052204 | ... | -1.206587 | 0.476628 | -1.005947 | -0.642069 | 1.539171 | -0.178439 | 0.726934 | Frame_Shift_Ins | A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... | True |
| 1 | 0.425827 | 0.714143 | -0.103793 | -0.752483 | -1.235371 | -1.741594 | 2.144800 | 0.657070 | 0.007487 | 0.437846 | ... | 0.685931 | -0.011439 | -1.243997 | 0.325185 | 0.441359 | 0.742193 | 1.510547 | In_Frame_Del | A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... | True |
| 2 | -0.403549 | -0.472462 | 1.258262 | 1.400088 | -0.061798 | 0.871830 | 0.128522 | -1.267703 | -0.862442 | 1.198007 | ... | -0.590583 | 1.160248 | -0.329160 | -0.087477 | 0.921045 | -0.355494 | -0.996870 | Frame_Shift_Del | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | True |
| 3 | -1.444494 | 0.777441 | 0.162707 | 0.051650 | -0.133358 | 1.981800 | -0.055030 | -1.847446 | -0.697533 | -0.266645 | ... | 0.769839 | 0.365693 | -0.258399 | 0.580993 | 0.056862 | 0.071359 | 0.254326 | Frame_Shift_Del | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | True |
| 4 | 1.223990 | -1.133931 | 0.133754 | -1.225917 | -0.520306 | -1.036720 | -0.033708 | 2.109670 | 1.379861 | 0.245229 | ... | -0.558904 | -2.703079 | 1.188400 | -0.549939 | -0.302067 | 1.021121 | -0.523455 | Frame_Shift_Del | A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... | True |
5 rows × 554 columns
plt.figure(figsize=(10, 6))
sns.violinplot(x='is_true', y=df.columns[0], data=df)
plt.xlabel('Functional')
plt.title("Distribution of the first gene expression level for functional and dysfunctional cells")
plt.show()
Correlations of genes with labels¶
Functional vs Dysfunctional¶
from scipy import stats
correlations = {}
labels = df['is_true'].to_numpy().astype(np.float64)
for gene in df.columns:
    if gene in ['is_true', 'mutation', 'Variant_Classification']:
        continue
    expression = df[gene].to_numpy().astype(np.float64)
    corr, pval = stats.pointbiserialr(labels, expression)
    correlations[gene] = corr
corrs = list(correlations.values())
corrs = np.array(corrs)
np.abs(corrs).mean(), np.abs(corrs).std(), np.abs(corrs).max()
(0.027919123954705125, 0.01954913018439728, 0.08834416504312281)
Individual genes are only weakly correlated with the labels. Most of them would contribute little more than noise for a classifier, so it is likely better to remove them. Let's visualize the sorted correlations as a curve.
corrs.sort()
plt.plot(corrs)
# save genes in order of absolute correlation for later use
import json
sorted_correlations = sorted(correlations.items(), key=lambda x: np.abs(x[1]), reverse=True)
good_genes = [corr[0] for corr in sorted_correlations]
with open('good_genes_tf.txt', 'w') as f:
json.dump(good_genes, f)
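Later cells can reload this ranking and keep only the top-k genes as features. A minimal self-contained sketch of the round-trip (the gene names, the `good_genes_demo.txt` filename, and the k=2 cutoff here are illustrative):

```python
import json

# an illustrative ranking (in the notebook this comes from sorted_correlations)
ranked_genes = ["ZMAT3..ENSG00064393", "AEN..ENSG00064782", "ACTB..ENSG00060"]
with open("good_genes_demo.txt", "w") as f:
    json.dump(ranked_genes, f)

# later: reload the ranking and keep only the k most correlated genes
with open("good_genes_demo.txt") as f:
    loaded = json.load(f)
top_k = loaded[:2]
print(top_k)
```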
Mutation type¶
from scipy import stats
correlations = {}
labels = (df['mutation'] == 'Missense_Mutation').to_numpy().astype(np.float64)
for gene in df.columns:
    if gene in ['is_true', 'mutation', 'Variant_Classification']:
        continue
    expression = df[gene].to_numpy().astype(np.float64)
    corr, pval = stats.pointbiserialr(labels, expression)
    correlations[gene] = corr
corrs = list(correlations.values())
corrs = np.array(corrs)
np.abs(corrs).mean()
np.abs(corrs).std()
np.abs(corrs).max()  # only this last expression is echoed by the notebook
0.09154977907792622
corrs.sort()
plt.plot(corrs)
import json
sorted_correlations = sorted(correlations.items(), key=lambda x: np.abs(x[1]), reverse=True)
good_genes = [corr[0] for corr in sorted_correlations]
with open('good_genes_missense.txt', 'w') as f:
json.dump(good_genes, f)
Classification¶
from copy import deepcopy
df = deepcopy(df_full)
df = log_and_normalize(df)
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = df_full['mutation']
We check the proportions of each class.
y.value_counts() / len(y)
mutation
Missense_Mutation         0.642840
Nonsense_Mutation         0.129898
Frame_Shift_Del           0.093090
Splice_Site               0.066018
Frame_Shift_Ins           0.028734
In_Frame_Del              0.018048
Splice_Region             0.011874
Fusion_                   0.005937
In_Frame_Ins              0.003325
Translation_Start_Site    0.000237
Name: count, dtype: float64
Since non-missense mutations aren't well represented, we group them under a single label.
y = y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
y.value_counts() / len(y)
mutation
1    0.64284
0    0.35716
Name: count, dtype: float64
We split the dataset into three subsets: a training set to fit our models, a validation set to tune hyperparameters, and a held-out test set to assess the final results.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
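The split above does not stratify by class; with a roughly 64/36 imbalance, passing `stratify=y` would keep the class proportions nearly identical across the three subsets. A sketch of this alternative on synthetic data (not what was actually run):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

X = pd.DataFrame(np.random.default_rng(0).normal(size=(100, 5)))
y = pd.Series([1] * 64 + [0] * 36)  # ~64/36 imbalance, as in our labels

X_tr, X_tmp, y_tr, y_tmp = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)
X_va, X_te, y_va, y_te = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=0, stratify=y_tmp)
# each subset preserves roughly the 64% positive rate
print(y_tr.mean(), y_va.mean(), y_te.mean())
```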
We now proceed to fit classifiers to our data. We focus here on the task of predicting missense vs non-missense mutations on the TCGA dataset; the results for the other three data-task combinations are similar and can be found in the Appendix.
To solve this task, we refrain from using deep learning and instead resort to classical supervised learning methods, which are simpler and less prone to overfitting. This choice is motivated by the scarcity of available labeled samples.
We use k-fold cross validation to obtain statistically robust estimates of the validation error of different hyperparameter combinations.
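This cross-validation procedure is what scikit-learn's `cross_val_score` implements: each fold is held out once while the model trains on the remaining folds. A minimal sketch on synthetic data (the model and data here are illustrative, not our pipeline):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(150, 10))
y = (X[:, 0] + 0.5 * rng.normal(size=150) > 0).astype(int)  # label driven by feature 0

# 5-fold CV: five accuracy estimates, one per held-out fold
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=5, scoring="accuracy")
print(scores.mean(), scores.std())
```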
Random forest¶
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
def hyperparameter_search(
model,
X_train,
y_train,
X_val,
y_val,
param_grid,
search_type="grid",
n_iter=10,
scoring="accuracy",
cv=5,
verbose=2,
):
"""
Perform hyperparameter search using grid search or random search.
Parameters:
- model: The machine learning model to tune.
- X_train: Training feature set.
- y_train: Training target variable.
- X_val: Validation feature set.
- y_val: Validation target variable.
- param_grid: Dictionary of hyperparameters to search over.
- search_type: 'grid' for GridSearchCV or 'random' for RandomizedSearchCV.
- n_iter: Number of iterations for RandomizedSearchCV (ignored for GridSearchCV).
- scoring: Scoring metric to use for evaluation.
- cv: Number of cross-validation folds.
- verbose: Verbosity level for the search.
Returns:
- best_model: The model with the best hyperparameters.
- best_params: The best hyperparameters found during the search.
- best_score: The best score achieved with the best hyperparameters.
- all_results: DataFrame with hyperparameters and corresponding validation scores.
"""
if search_type == "grid":
search = GridSearchCV(
estimator=model,
param_grid=param_grid,
scoring=scoring,
cv=cv,
return_train_score=True,
verbose=verbose,
n_jobs=4
)
elif search_type == "random":
search = RandomizedSearchCV(
estimator=model,
param_distributions=param_grid,
scoring=scoring,
cv=cv,
n_iter=n_iter,
random_state=0,
return_train_score=True,
verbose=verbose,
n_jobs=4
)
else:
raise ValueError("search_type must be either 'grid' or 'random'")
search.fit(X_train, y_train)
best_model = search.best_estimator_
best_params = search.best_params_
best_score = search.best_score_
val_predictions = best_model.predict(X_val)
val_score = accuracy_score(y_val, val_predictions)
print(f"Validation Score with best hyperparameters: {val_score}")
# Collect all results
results = search.cv_results_
all_results = pd.DataFrame(results)
return best_model, best_params, best_score, all_results
model = RandomForestClassifier()
param_grid = {
'n_estimators': [100, 200],
'max_depth': [3, 5, 7, 10, 15, 20],
'bootstrap': [True, False],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=2)
Fitting 3 folds for each of 24 candidates, totalling 72 fits
[CV] END ......bootstrap=True, max_depth=3, n_estimators=100; total time= 2.1s
...
Validation Score with best hyperparameters: 0.6484560570071259
best_score, best_params
(0.6374702448506858, {'bootstrap': True, 'max_depth': 3, 'n_estimators': 200})
all_results
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_bootstrap | param_max_depth | param_n_estimators | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2.112955 | 0.005620 | 0.028860 | 0.003517 | True | 3 | 100 | {'bootstrap': True, 'max_depth': 3, 'n_estimat... | 0.637578 | 0.637578 | 0.636364 | 0.637173 | 0.000572 | 5 | 0.637416 | 0.637416 | 0.639359 | 0.638064 | 0.000916 |
| 1 | 5.028177 | 0.136659 | 0.034389 | 0.012648 | True | 3 | 200 | {'bootstrap': True, 'max_depth': 3, 'n_estimat... | 0.637578 | 0.637578 | 0.637255 | 0.637470 | 0.000152 | 1 | 0.637416 | 0.637416 | 0.638468 | 0.637767 | 0.000496 |
| 2 | 3.512773 | 0.584872 | 0.017173 | 0.003729 | True | 5 | 100 | {'bootstrap': True, 'max_depth': 5, 'n_estimat... | 0.635797 | 0.638468 | 0.637255 | 0.637173 | 0.001092 | 4 | 0.662361 | 0.674833 | 0.670525 | 0.669240 | 0.005172 |
| 3 | 5.515238 | 0.007261 | 0.023251 | 0.001754 | True | 5 | 200 | {'bootstrap': True, 'max_depth': 5, 'n_estimat... | 0.637578 | 0.637578 | 0.635472 | 0.636876 | 0.000993 | 6 | 0.659243 | 0.668597 | 0.676313 | 0.668051 | 0.006980 |
| 4 | 3.552134 | 0.035117 | 0.015851 | 0.000998 | True | 7 | 100 | {'bootstrap': True, 'max_depth': 7, 'n_estimat... | 0.634907 | 0.633126 | 0.630125 | 0.632719 | 0.001973 | 11 | 0.767483 | 0.787082 | 0.776046 | 0.776871 | 0.008023 |
| 5 | 7.148444 | 0.140806 | 0.027100 | 0.001456 | True | 7 | 200 | {'bootstrap': True, 'max_depth': 7, 'n_estimat... | 0.634907 | 0.637578 | 0.630125 | 0.634203 | 0.003083 | 9 | 0.787973 | 0.781292 | 0.766251 | 0.778505 | 0.009084 |
| 6 | 4.993407 | 0.269705 | 0.025283 | 0.005352 | True | 10 | 100 | {'bootstrap': True, 'max_depth': 10, 'n_estima... | 0.634907 | 0.619768 | 0.622103 | 0.625593 | 0.006654 | 15 | 0.933630 | 0.935412 | 0.910953 | 0.926665 | 0.011134 |
| 7 | 10.375836 | 0.145891 | 0.064294 | 0.029963 | True | 10 | 200 | {'bootstrap': True, 'max_depth': 10, 'n_estima... | 0.635797 | 0.631345 | 0.623886 | 0.630343 | 0.004914 | 13 | 0.947884 | 0.946548 | 0.915850 | 0.936761 | 0.014796 |
| 8 | 6.138031 | 0.142559 | 0.025194 | 0.002995 | True | 15 | 100 | {'bootstrap': True, 'max_depth': 15, 'n_estima... | 0.625111 | 0.619768 | 0.599822 | 0.614901 | 0.010883 | 23 | 0.979510 | 0.978174 | 0.977293 | 0.978326 | 0.000911 |
| 9 | 12.712065 | 0.163985 | 0.046701 | 0.001244 | True | 15 | 200 | {'bootstrap': True, 'max_depth': 15, 'n_estima... | 0.634016 | 0.623330 | 0.614082 | 0.623809 | 0.008145 | 18 | 0.979510 | 0.978619 | 0.977738 | 0.978622 | 0.000723 |
| 10 | 6.724589 | 0.325609 | 0.028120 | 0.004744 | True | 20 | 100 | {'bootstrap': True, 'max_depth': 20, 'n_estima... | 0.629564 | 0.622440 | 0.607843 | 0.619949 | 0.009041 | 20 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
| 11 | 13.505669 | 0.397157 | 0.049830 | 0.000677 | True | 20 | 200 | {'bootstrap': True, 'max_depth': 20, 'n_estima... | 0.628673 | 0.612645 | 0.620321 | 0.620546 | 0.006546 | 19 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
| 12 | 2.749218 | 0.104987 | 0.013929 | 0.004260 | False | 3 | 100 | {'bootstrap': False, 'max_depth': 3, 'n_estima... | 0.637578 | 0.637578 | 0.637255 | 0.637470 | 0.000152 | 1 | 0.637416 | 0.637416 | 0.640695 | 0.638509 | 0.001545 |
| 13 | 5.937420 | 0.274229 | 0.021186 | 0.001302 | False | 3 | 200 | {'bootstrap': False, 'max_depth': 3, 'n_estima... | 0.637578 | 0.637578 | 0.637255 | 0.637470 | 0.000152 | 1 | 0.637416 | 0.637416 | 0.639804 | 0.638212 | 0.001126 |
| 14 | 4.377419 | 0.017964 | 0.014611 | 0.000963 | False | 5 | 100 | {'bootstrap': False, 'max_depth': 5, 'n_estima... | 0.637578 | 0.636687 | 0.630125 | 0.634797 | 0.003323 | 8 | 0.666370 | 0.681069 | 0.696349 | 0.681263 | 0.012240 |
| 15 | 8.535188 | 0.082705 | 0.023125 | 0.001524 | False | 5 | 200 | {'bootstrap': False, 'max_depth': 5, 'n_estima... | 0.636687 | 0.636687 | 0.636364 | 0.636580 | 0.000153 | 7 | 0.668151 | 0.675724 | 0.694568 | 0.679481 | 0.011107 |
| 16 | 5.826510 | 0.076224 | 0.018134 | 0.001061 | False | 7 | 100 | {'bootstrap': False, 'max_depth': 7, 'n_estima... | 0.640249 | 0.633126 | 0.626560 | 0.633312 | 0.005590 | 10 | 0.801336 | 0.813363 | 0.804541 | 0.806414 | 0.005085 |
| 17 | 12.183470 | 0.262563 | 0.041748 | 0.007930 | False | 7 | 200 | {'bootstrap': False, 'max_depth': 7, 'n_estima... | 0.636687 | 0.634907 | 0.625668 | 0.632421 | 0.004830 | 12 | 0.802673 | 0.819154 | 0.802315 | 0.808047 | 0.007855 |
| 18 | 8.546466 | 0.218283 | 0.032218 | 0.000473 | False | 10 | 100 | {'bootstrap': False, 'max_depth': 10, 'n_estim... | 0.635797 | 0.624221 | 0.615865 | 0.625294 | 0.008173 | 16 | 0.960802 | 0.963029 | 0.954586 | 0.959472 | 0.003573 |
| 19 | 15.394531 | 0.323630 | 0.040654 | 0.001421 | False | 10 | 200 | {'bootstrap': False, 'max_depth': 10, 'n_estim... | 0.635797 | 0.626892 | 0.619430 | 0.627373 | 0.006691 | 14 | 0.964365 | 0.966147 | 0.950134 | 0.960215 | 0.007166 |
| 20 | 9.346233 | 0.071505 | 0.027116 | 0.001459 | False | 15 | 100 | {'bootstrap': False, 'max_depth': 15, 'n_estim... | 0.634016 | 0.617988 | 0.604278 | 0.618761 | 0.012153 | 22 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
| 21 | 18.498514 | 0.128986 | 0.046684 | 0.000652 | False | 15 | 200 | {'bootstrap': False, 'max_depth': 15, 'n_estim... | 0.634907 | 0.625111 | 0.612299 | 0.624106 | 0.009257 | 17 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
| 22 | 10.118011 | 0.112264 | 0.026858 | 0.000628 | False | 20 | 100 | {'bootstrap': False, 'max_depth': 20, 'n_estim... | 0.606411 | 0.622440 | 0.610517 | 0.613123 | 0.006798 | 24 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
| 23 | 16.992818 | 0.693453 | 0.039370 | 0.000675 | False | 20 | 200 | {'bootstrap': False, 'max_depth': 20, 'n_estim... | 0.628673 | 0.622440 | 0.606061 | 0.619058 | 0.009536 | 21 | 0.979510 | 0.979065 | 0.977738 | 0.978771 | 0.000753 |
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.68
The validation accuracy of the best-performing random forest is poor: only slightly above the fraction of samples in the most represented class. Training accuracy, however, is high: the models are overfitting. Later on, to regularize training, we will apply feature selection and feature extraction to reduce noise in the training data and hopefully improve performance.
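The majority-class baseline mentioned above can be made explicit with scikit-learn's `DummyClassifier`; a small sketch with synthetic labels mirroring our ~64/36 split (not our actual data):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

X = np.zeros((100, 3))             # features are ignored by the dummy
y = np.array([1] * 64 + [0] * 36)  # ~64% majority class, as in our labels

baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
print(baseline.score(X, y))  # 0.64: accuracy of always predicting the majority class
```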
To visualize the impact of each hyperparameter on validation performance, we use bar plots, averaging over the remaining hyperparameters.
from typing import Dict
def plot_hyperparameter_search_results(
all_results: pd.DataFrame, param_grid: Dict, score_metric="mean_test_score"
):
"""
Plot the effect of each hyperparameter on the validation score. Fixed hyperparameters are averaged over.
:param all_results: dataframe with search result, as output by hyperparameter_search
:param param_grid: dictionary of all hyperparameters and their values
:param score_metric: metric to plot
"""
    fig, axes = plt.subplots(len(param_grid), 1, figsize=(10, 3 * len(param_grid)))
    axes = np.atleast_1d(axes)  # plt.subplots returns a bare Axes when there is a single hyperparameter
    for ax, (param, values) in zip(axes, param_grid.items()):
        means = all_results.groupby(f"param_{param}")[score_metric].mean()
        print(means)
        # one bar per hyperparameter value, averaged over all other hyperparameters
        ax.bar(means.index, means.values)
ax.set_title(f"Effect of {param} on {score_metric}")
ax.set_xlabel(param)
ax.set_ylabel(score_metric)
plt.tight_layout()
plot_hyperparameter_search_results(all_results, param_grid)
param_n_estimators
100    0.627522
200    0.630021
Name: mean_test_score, dtype: float64
param_max_depth
3     0.637396
5     0.636356
7     0.633164
10    0.627151
15    0.620394
20    0.618169
Name: mean_test_score, dtype: float64
param_bootstrap
False    0.628314
True     0.629230
Name: mean_test_score, dtype: float64
We use a confusion matrix to understand the types of errors our best-performing random forest makes.
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
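Beyond raw counts, per-class precision and recall summarize the same confusion matrix; a sketch using scikit-learn's `classification_report` (the labels here are illustrative, not our test set):

```python
from sklearn.metrics import classification_report

y_true = [1, 1, 1, 0, 0, 1, 0, 1]
y_pred = [1, 1, 0, 0, 1, 1, 0, 1]
# prints precision, recall, and F1 for each class, plus overall averages
print(classification_report(y_true, y_pred, target_names=["non-missense", "missense"]))
```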
Other Supervised methods¶
from sklearn.linear_model import LogisticRegression
model = LogisticRegression(max_iter=10000)
param_grid = {
"penalty": ['l2', None],
"C": [0.001, 0.005, 0.01, 0.1],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=2)
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
plot_hyperparameter_search_results(all_results, param_grid)
plt.show()
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 8 candidates, totalling 24 fits
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters warnings.warn(
[CV] END ................................C=0.001, penalty=l2; total time= 0.2s
...
Validation Score with best hyperparameters: 0.6555819477434679
Test accuracy of best model: 0.68
param_penalty
l2    0.60726
Name: mean_test_score, dtype: float64
param_C
0.001    0.603770
0.005    0.594864
0.010    0.590857
0.100    0.581205
Name: mean_test_score, dtype: float64
from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier()
param_grid = {
'n_neighbors': [1, 3, 5, 10, 15],
'weights': ['uniform', 'distance'],
'p': [1, 2]
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=2)
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
plot_hyperparameter_search_results(all_results, param_grid)
plt.show()
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 20 candidates, totalling 60 fits
[CV] END ................n_neighbors=1, p=1, weights=uniform; total time= 1.2s
...
Validation Score with best hyperparameters: 0.6460807600950119
Test accuracy of best model: 0.63
param_n_neighbors
1     0.558192
3     0.582242
5     0.592410
10    0.593970
15    0.611267
Name: mean_test_score, dtype: float64
param_weights
distance    0.587468
uniform     0.587765
Name: mean_test_score, dtype: float64
param_p
1    0.588537
2    0.586696
Name: mean_test_score, dtype: float64
from sklearn.svm import SVC
model = SVC()
param_grid = {
'C': [0.001, 0.005, 0.01, 0.1],
'gamma': ['scale', 'auto'],
"kernel": ['rbf', 'poly', 'sigmoid', 'linear'],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=2)
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
plot_hyperparameter_search_results(all_results, param_grid)
plt.show()
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 32 candidates, totalling 96 fits
[per-fit [CV] timing logs omitted]
Validation Score with best hyperparameters: 0.6484560570071259
Test accuracy of best model: 0.68
param_C
0.001    0.636950
0.005    0.631903
0.010    0.629083
0.100    0.621734
Name: mean_test_score, dtype: float64
param_gamma
auto     0.629918
scale    0.629918
Name: mean_test_score, dtype: float64
param_kernel
linear     0.607779
poly       0.636951
rbf        0.637470
sigmoid    0.637470
Name: mean_test_score, dtype: float64
scGPT embeddings¶
scGPT is a foundation model for single-cell genomics created by Cui et al. ("scGPT: Towards Building a Foundation Model for Single-Cell Multi-omics Using Generative AI") that can be used to generate embeddings for genes and cells. We use its frozen embeddings as input features for a random forest classifier.
# constant definition and utility functions
set_seed(42)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51
mask_value = -1
pad_value = -2
n_input_bins = n_bins
device = torch.device("cpu")
def preprocess_df(name, extension='csv'):
"""
Preparing dataset for scGPT
"""
csv_file = f'data/{name}.{extension}'
df = pd.read_csv(csv_file)
df = df.drop(columns=['is_true', "Variant_Classification", "mutation"])
df = df.clip(lower=0)
df.columns = [s.split("..")[0] for s in df.columns]
df.to_csv(f'data/{name}_processed_gpt.{extension}', index=False)
def load_model(model_dir="model_params"):
# the weights can be downloaded from https://github.com/bowang-lab/scGPT/tree/main?tab=readme-ov-file#pretrained-scGPT-checkpoints
# we used the whole human model
model_config_file = model_dir + "/args.json"
model_file = model_dir + "/best_model.pt"
vocab_file = model_dir + "/vocab.json"
vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
if s not in vocab:
vocab.append_token(s)
with open(model_config_file, "r") as f:
model_configs = json.load(f)
print(
f"Resume model from {model_file}, the model args will override the "
f"config {model_config_file}."
)
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]
ntokens = len(vocab) # size of vocabulary
model = TransformerModel(
ntokens,
embsize,
nhead,
d_hid,
nlayers,
vocab=vocab,
pad_value=pad_value,
n_input_bins=n_input_bins,
)
try:
model.load_state_dict(torch.load(model_file, map_location=device))
print(f"Loading all model params from {model_file}")
except Exception:
# only load params that are in the model and match the size
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file, map_location=device)
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if k in model_dict and v.shape == model_dict[k].shape
}
for k, v in pretrained_dict.items():
print(f"Loading params {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.to(device)
gene2idx = vocab.get_stoi()
return model, gene2idx, vocab_file
def load_and_preprocess(name, extension='csv'):
adata = sc.read(f"data/{name}_processed_gpt.{extension}", cache=False)
ori_batch_col = "batch"
data_is_raw = True
# Preprocess the data following the scGPT data pre-processing pipeline
preprocessor = Preprocessor(
use_key="X", # the key in adata.layers to use as raw data
filter_gene_by_counts=3, # step 1
filter_cell_by_counts=False, # step 2
normalize_total=1e4, # 3. whether to normalize the raw data and to what sum
result_normed_key="X_normed", # the key in adata.layers to store the normalized data
log1p=data_is_raw, # 4. whether to log1p the normalized data
result_log1p_key="X_log1p",
subset_hvg=n_hvg, # 5. whether to subset the raw data to highly variable genes
hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
binning=n_bins, # 6. whether to bin the raw data and to what number of bins
result_binned_key="X_binned", # the key in adata.layers to store the binned data
)
preprocessor(adata)
return adata
def compute_gene_embeddings(gene2idx, model_gpt, adata):
"""
Compute embedding of each gene in adata
"""
gene_ids = np.array([id for id in gene2idx.values()])
gene_embeddings = model_gpt.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
gene_embeddings = gene_embeddings.detach().cpu().numpy()
# Filter on the intersection between the Immune Human HVGs found in step 1.2 and scGPT's 30+K foundation model vocab
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if
gene in adata.var.index.tolist()}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))
return gene_embeddings
def get_cell_embeddings(adata, gene_embeddings):
"""
Compute cell embeddings of adata
"""
cell_embeddings_l = []
for cell_idx in tqdm.tqdm(range(adata.shape[0])):
cell_expression = adata[cell_idx].X.toarray().flatten()
cell_embedding = np.zeros_like(next(iter(gene_embeddings.values())))
for gene_idx, expression_level in enumerate(cell_expression):
gene_name = adata.var.index[gene_idx]
if gene_name in gene_embeddings:
cell_embedding += gene_embeddings[gene_name] * expression_level
cell_embeddings_l.append(cell_embedding)
cell_embeddings = np.array(cell_embeddings_l)
print('Computed embeddings for {} cells.'.format(cell_embeddings.shape[0]))
return cell_embeddings
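The per-gene loop in `get_cell_embeddings` is equivalent to a single matrix product between the expression matrix and the stacked gene-embedding matrix. A vectorized sketch (the helper name is ours, and it assumes a dense expression matrix whose columns are aligned with the gene names):

```python
import numpy as np

def get_cell_embeddings_vectorized(X, gene_names, gene_embeddings):
    """Expression-weighted sum of gene embeddings, computed as one matmul.

    X: (n_cells, n_genes) dense expression matrix, columns aligned with gene_names.
    gene_embeddings: dict mapping gene name -> 1-D embedding vector.
    """
    emb_dim = len(next(iter(gene_embeddings.values())))
    # Stack embeddings into (n_genes, emb_dim); genes missing from the
    # vocabulary contribute a zero row, matching the loop's behaviour.
    E = np.array([np.asarray(gene_embeddings.get(g, np.zeros(emb_dim)))
                  for g in gene_names])
    return X @ E  # (n_cells, emb_dim)
```

This avoids the Python-level inner loop over genes, which matters once the number of cells grows into the thousands.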
We preprocess the data, load the model and compute the embeddings.
# Load the pre-trained scGPT model and preprocess the data
model_gpt, gene2idx, vocab_file = load_model()
preprocess_df("CCLE_labels")
preprocess_df("TCGA_labels")
Resume model from model_params/best_model.pt, the model args will override the config model_params/args.json.
Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
[per-parameter loading logs omitted: encoder norms, value encoder, 12 transformer encoder layers, and decoder, all loaded with matching shapes]
adataCCLE = load_and_preprocess("CCLE_labels")
adataTCGA = load_and_preprocess("TCGA_labels")
gene_embeddings_ccle = compute_gene_embeddings(gene2idx, model_gpt,adataCCLE)
gene_embeddings_tcga = compute_gene_embeddings(gene2idx, model_gpt, adataTCGA)
cell_embeddings_ccle = get_cell_embeddings(adataCCLE, gene_embeddings_ccle)
cell_embeddings_tcga = get_cell_embeddings(adataTCGA, gene_embeddings_tcga)
X_train_embedded, X_test_embedded, y_train, y_test = train_test_split(cell_embeddings_tcga, y, test_size=0.2, random_state=0)
X_val_embedded, X_test_embedded, y_val, y_test = train_test_split(X_test_embedded, y_test, test_size=0.5, random_state=0)
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - WARNING - No batch_key is provided, will use all cells for HVG selection.
scGPT - INFO - Binning data ...
[the same preprocessing log repeats for the second dataset]
Retrieved gene embeddings for 579 genes.
Retrieved gene embeddings for 545 genes.
100%|██████████| 924/924 [00:02<00:00, 443.69it/s]
Computed embeddings for 924 cells.
100%|██████████| 4211/4211 [00:08<00:00, 494.44it/s]
Computed embeddings for 4211 cells.
model = RandomForestClassifier()
param_grid = {
'n_estimators': [100, 200],
'max_depth': [3, 5, 7, 10, 15, 20],
'bootstrap': [True, False],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train_embedded, y_train, X_val_embedded, y_val,
param_grid, search_type='grid', cv=3,
verbose=2)
y_pred = best_model.predict(X_test_embedded)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
plot_hyperparameter_search_results(all_results, param_grid)
plt.show()
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 24 candidates, totalling 72 fits
Validation Score with best hyperparameters: 0.6484560570071259
Test accuracy of best model: 0.68
param_n_estimators
100    0.623564
200    0.624950
Name: mean_test_score, dtype: float64
param_max_depth
3     0.636728
5     0.634204
7     0.627673
10    0.622624
15    0.614681
20    0.609634
Name: mean_test_score, dtype: float64
param_bootstrap
False    0.624233
True     0.624282
Name: mean_test_score, dtype: float64
As we can see, the results are completely in line with the previous ones: the accuracy of the best model is still only slightly above that of always predicting the most represented class.
Feature Selection using correlations¶
Import previously computed features, sorted by correlation with the labels.
import json
with open('good_genes_missense.txt') as f:
ordered_genes = json.load(f)
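The ranking file is precomputed; for reference, a minimal sketch of how genes could be ordered by absolute correlation with a binary label (the function name is ours, not from the notebook):

```python
import numpy as np
import pandas as pd

def rank_genes_by_label_correlation(X, y):
    """Return gene names sorted by |Pearson correlation| with the labels.

    X: DataFrame (cells x genes); y: binary labels aligned with X's rows.
    With a 0/1 label, Pearson correlation coincides with the
    point-biserial correlation.
    """
    y = pd.Series(np.asarray(y), index=X.index)
    corr = X.corrwith(y).abs()  # column-wise correlation with the labels
    return corr.sort_values(ascending=False).index.tolist()
```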
N_values = [10, 25, 50, 75, 100, 125, 150, 200, 300]
val_scores = []
train_scores = []
for N in N_values:
good_genes = ordered_genes[:N]
X_train_good = X_train[good_genes]
X_val_good = X_val[good_genes]
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_good, y_train)
val_scores.append(model.score(X_val_good, y_val))
train_scores.append(model.score(X_train_good, y_train))
plt.plot(N_values, val_scores, label='Validation')
plt.plot(N_values, train_scores, label='Train')
plt.xlabel('Number of genes')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy vs number of genes selected')
plt.show()
N = 75
good_genes = ordered_genes[:N]
X_train_good = X_train[good_genes]
X_val_good = X_val[good_genes]
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_good, y_train)
X_test_good = X_test[good_genes]
model.score(X_test_good, y_test)
0.6469194312796208
y_pred = model.predict(X_test_good)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Feature Extraction using PCA¶
N_values = [10, 25, 50, 75, 100, 125, 150, 200, 300]
val_scores = []
train_scores = []
for N in N_values:
pca = PCA(n_components=N)
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_val_pca = pca.transform(X_val)
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_pca, y_train)
val_scores.append(model.score(X_val_pca, y_val))
train_scores.append(model.score(X_train_pca, y_train))
plt.plot(N_values, val_scores, label='Validation')
plt.plot(N_values, train_scores, label='Train')
plt.xlabel('Number of PCA components')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy vs number of PCA components')
plt.show()
N = 25
pca = PCA(n_components=N)
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_val_pca = pca.transform(X_val)
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_pca, y_train)
X_test_pca = pca.transform(X_test)
model.score(X_test_pca, y_test)
0.6658767772511849
y_pred = model.predict(X_test_pca)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Class weights¶
class_weights = {
1: len(y_train) / (y_train == 1).sum(),
0: len(y_train) / (y_train == 0).sum()
}
N = 75
good_genes = ordered_genes[:N]
X_train_good = X_train[good_genes]
X_val_good = X_val[good_genes]
model = RandomForestClassifier(n_estimators=200, max_depth=5, bootstrap=True, class_weight=class_weights)
model.fit(X_train_good, y_train)
X_test_good = X_test[good_genes]
model.score(X_test_good, y_test)
0.514218009478673
y_pred = model.predict(X_test_good)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
By weighting the loss terms of the different classes inversely to the class proportions, we obtain more balanced predictions. However, the model is still wrong about 50% of the time for each class.
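The "balanced but ~50% wrong per class" behaviour can be read directly off the confusion matrix as per-class recall. A minimal sketch with a hypothetical confusion matrix (the numbers below are illustrative, not taken from our results):

```python
import numpy as np

# Hypothetical confusion matrix in sklearn's layout
# (rows = true class, columns = predicted class).
cm = np.array([[52, 48],
               [49, 51]])

# Per-class recall: correct predictions (diagonal) divided by row totals.
per_class_recall = np.diag(cm) / cm.sum(axis=1)
print(per_class_recall)  # [0.52 0.51]
```

Balanced errors with recall near 0.5 for every class indicate the classifier is doing little better than a coin flip once the class imbalance is corrected for.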
Convolutional Graph Neural Network (GCN)¶
We try another approach, modelling each cell as a graph whose nodes are the genes of the dataset. Due to limited computational resources, we use a very simple embedding for each node: a vector of dimension 1 given by the expression level of that gene in the given cell. The edges of the graph (and their weights) are based on the correlation between genes across the dataset. Three inputs are thus provided: the topology of the graph, the edge embeddings, and the node embeddings; the first two are constructed from the entire dataset, while the last is cell-specific. The problem then becomes a graph classification task, in which we classify the mutation of each cell, in the hope that the message passing layers, by exploiting the correlations between genes, will rapidly spread the necessary information across all the nodes.
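The graph construction can be sketched on a toy correlation matrix (values are illustrative); the actual code below applies the same thresholding to the full gene-gene correlation matrix and converts it with torch_geometric's from_scipy_sparse_matrix:

```python
import numpy as np
import scipy.sparse as sp

# Toy correlation matrix over 4 "genes" (illustrative values only).
corr = np.array([[1.0,  0.8,  0.02, -0.3],
                 [0.8,  1.0,  0.01,  0.6],
                 [0.02, 0.01, 1.0,   0.04],
                 [-0.3, 0.6,  0.04,  1.0]])

# Zero out weak correlations, mirroring the 0.05 threshold used below.
corr[np.abs(corr) < 0.05] = 0

# Sparse COO form yields the edge list (row, col) and edge weights directly.
adj = sp.coo_matrix(corr)
edge_index = np.vstack([adj.row, adj.col])
print(edge_index.shape[1])  # number of (directed) edges kept
```

Only the entries surviving the threshold become edges, so the graph topology is shared by all cells while the node features (expression levels) vary per cell.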
import torch
import matplotlib.pyplot as plt
from torch.nn.functional import tanh, log_softmax
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import pandas as pd
import numpy as np
import networkx as nx
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import scipy.sparse as sp
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
First, let's define the structure of our GNN.
# definition of the GCN
class GCN(torch.nn.Module):
    def __init__(self, n_features, emb_size, n_classes, n_mp_layers):
        """
        Parameters:
        n_features: dimension of the input node features
        emb_size: embedding size of each node, used in the message passing steps
        n_classes: number of classes for the classification task
        n_mp_layers: number of message passing steps to perform
        """
        super(GCN, self).__init__()
        self.n_mp_layers = n_mp_layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(n_features, emb_size))  # transforms node features into embedding vectors
        for i in range(1, n_mp_layers):
            self.convs.append(GCNConv(emb_size, emb_size))  # message passing step
        self.out = torch.nn.Sequential(  # post-message-passing head
            Linear(emb_size, n_classes)  # maps embeddings to one score per class
        )

    def forward(self, batch):
        # batch is a PyTorch Geometric batch: x is the feature matrix of size
        # (n_nodes x embedding_dim), edge_index is the adjacency list, and
        # batch.batch maps each node to the graph it belongs to
        x, edge_index, batch = batch.x, batch.edge_index, batch.batch
        out = x
        for i in range(self.n_mp_layers - 1):
            out = self.convs[i](out, edge_index)
            out = tanh(out)
        out = self.convs[self.n_mp_layers - 1](out, edge_index)
        out = global_mean_pool(out, batch)  # graph-level pooling
        out = self.out(out)
        return log_softmax(out, dim=1)
#for TCGA
csv_label_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_label_file)
scaler = StandardScaler()
Y = df['mutation']
Y = Y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
X = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
X_norm = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)
# Split the data into train and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X_norm, Y, test_size=0.2, random_state=42)
corr_matrix = X_train.corr()
corr_matrix[abs(corr_matrix)<0.05] = 0
# Create sparse adjacency matrix
adj_matrix = sp.coo_matrix(corr_matrix.values)
edge_index, edge_attr = from_scipy_sparse_matrix(adj_matrix)
X_train = np.array(X_train)
X_test = np.array(X_test)
Y_train = np.array(Y_train)
Y_test = np.array(Y_test)
# Create PyTorch Geometric data objects for training set
tr_data = []
for i in range(len(X_train)):
    x = torch.tensor(X_train[i], dtype=torch.float).reshape([-1, 1])
    y = torch.tensor(Y_train[i], dtype=torch.long)
    data = Data(x=x, y=y, edge_index=edge_index)
    tr_data.append(data)
# Create PyTorch Geometric data objects for test set
test_data = []
for i in range(len(X_test)):
    x = torch.tensor(X_test[i], dtype=torch.float).reshape([-1, 1])
    y = torch.tensor(Y_test[i], dtype=torch.long)
    data = Data(x=x, y=y, edge_index=edge_index)
    test_data.append(data)
# DataLoader for batch processing
tr_loader = DataLoader(tr_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
# Model, criterion, optimizer
model = GCN(1, 1, 2, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training loop
model.train()
for epoch in range(1):  # a single epoch, for computational reasons
    for batch in tr_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
# Evaluation loop
model.eval()
correct = 0
for batch in test_loader:
    with torch.no_grad():
        out = model(batch)
        pred = out.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
val_score = correct / len(Y_test)
print(f'Test Accuracy: {val_score:.4f}')
Epoch 1, Loss: 0.4633154571056366
Test Accuracy: 0.6418
It can be observed that the final accuracy is not too different from those obtained with the other classifiers we used.
#for CCLE
csv_label_file = 'data/CCLE_labels.csv'
df = pd.read_csv(csv_label_file)
scaler = StandardScaler()
Y = df['mutation']
Y = Y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
X = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
X_norm = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)
# Split the data into train and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X_norm, Y, test_size=0.2, random_state=42)
corr_matrix = X_train.corr()
corr_matrix[abs(corr_matrix)<0.05] = 0
# Create sparse adjacency matrix
adj_matrix = sp.coo_matrix(corr_matrix.values)
edge_index, edge_attr = from_scipy_sparse_matrix(adj_matrix)
X_train = np.array(X_train)
X_test = np.array(X_test)
Y_train = np.array(Y_train)
Y_test = np.array(Y_test)
# Create PyTorch Geometric data objects for training set
tr_data = []
for i in range(len(X_train)):
    x = torch.tensor(X_train[i], dtype=torch.float).reshape([-1, 1])
    y = torch.tensor(Y_train[i], dtype=torch.long)
    data = Data(x=x, y=y, edge_index=edge_index)
    tr_data.append(data)
# Create PyTorch Geometric data objects for test set
test_data = []
for i in range(len(X_test)):
    x = torch.tensor(X_test[i], dtype=torch.float).reshape([-1, 1])
    y = torch.tensor(Y_test[i], dtype=torch.long)
    data = Data(x=x, y=y, edge_index=edge_index)
    test_data.append(data)
# DataLoader for batch processing
tr_loader = DataLoader(tr_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
# Model, criterion, optimizer
model = GCN(1, 1, 2, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training loop
model.train()
for epoch in range(1):  # a single epoch, for computational reasons
    for batch in tr_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
# Evaluation loop
model.eval()
correct = 0
for batch in test_loader:
    with torch.no_grad():
        out = model(batch)
        pred = out.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
val_score = correct / len(Y_test)
print(f'Test Accuracy: {val_score:.4f}')
Epoch 1, Loss: 0.41184601187705994
Test Accuracy: 0.6919
Again, we can see that the obtained accuracy is consistent with the results shown before. It would be highly interesting to observe the outcomes achievable by combining scGPT embeddings for each gene with the Graph Neural Network. However, due to the costly and resource-intensive computation this would require, we are unable to do so.
Conclusions¶
From our analysis, it emerged that the data is too noisy to solve either of the two tasks with accuracy meaningfully better than that obtained by always outputting the most frequent label. This is in line with what was found in the paper. Our models learn, but they overfit the training set and generalize poorly, despite the fact that we explicitly regularized them through hyperparameter choices in our grid searches.
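The majority-class reference point is easy to compute explicitly. A minimal sketch with sklearn's DummyClassifier on toy data whose imbalance is similar to the missense task (the numbers are illustrative, not taken from the actual dataset):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Toy labels with roughly a 65/35 imbalance; features are ignored by the baseline.
rng = np.random.default_rng(0)
y = (rng.random(1000) < 0.65).astype(int)
X = rng.standard_normal((1000, 5))

# Majority-class baseline: its accuracy equals the most frequent label's share.
baseline = DummyClassifier(strategy='most_frequent').fit(X, y)
print(baseline.score(X, y))
```

Any model whose test accuracy does not clearly exceed this score has effectively learned nothing beyond the class proportions.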
Furthermore, our attempts at feature extraction and selection through PCA and correlation analysis did not improve performance. We expected the regularizing effect of these methods, which remove some of the least informative features and should thus help reduce noise and mitigate overfitting, to benefit our tasks. We conclude that the signal in the data is not strong enough to solve our tasks with meaningful accuracy. Indeed, inspection of the confusion matrices shows that many of our models end up simply outputting the most frequent class every time.
Even when the training loss is weighted inversely to the class frequencies, although we obtain more balanced predictions, the performance remains poor, with each class being mispredicted about 50% of the time.
Appendix¶
We show here some results for the other task and dataset combinations. As anticipated, the conclusions are essentially the same as those presented for the classification of mutation type on the TCGA dataset. For this reason, and to avoid repetition and an overly heavy notebook, we do not include all our explorations and all the models we trained in this report.
CCLE Dataset¶
csv_file = 'data/CCLE_labels.csv'
df = pd.read_csv(csv_file)
df_full = df
df.head()
| Variant_Classification | MAD1L1..ENSG00000002822. | ITGA3..ENSG00000005884. | MYH13..ENSG00000006788. | GAS7..ENSG00000007237. | REV3L..ENSG00000009413. | TSPAN9..ENSG00000011105. | RNF216..ENSG00000011275. | CEP68..ENSG00000011523. | BRCA1..ENSG00000012048. | ... | BGLAP..ENSG00000242252. | MICAL3..ENSG00000243156. | FMN1..ENSG00000248905. | GATC..ENSG00000257218. | CUX1..ENSG00000257923. | BAHCC1..ENSG00000266074. | PRAG1..ENSG00000275342. | UHRF1..ENSG00000276043. | is_true | mutation | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1__639V_URINARY_TRACT_Missense_Mutation_c.(742... | 3102 | 8389 | 0 | 3 | 3104 | 1698 | 5130 | 1687 | 4486 | ... | 9 | 8144 | 3539 | 2944 | 6288 | 3160 | 140 | 23040 | False | Missense_Mutation |
| 1 | 1__BL41_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Mis... | 5645 | 312 | 0 | 14 | 4925 | 35 | 4856 | 1110 | 10004 | ... | 33 | 8346 | 40 | 4463 | 23703 | 133 | 4115 | 38105 | False | Missense_Mutation |
| 2 | 1__CA46_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Mis... | 6967 | 2113 | 0 | 113 | 8180 | 115 | 3648 | 2871 | 7615 | ... | 34 | 6574 | 234 | 4132 | 18149 | 265 | 1580 | 4790 | False | Missense_Mutation |
| 3 | 1__CAL29_URINARY_TRACT_Missense_Mutation_c.(84... | 1882 | 24720 | 0 | 41 | 1809 | 1731 | 4349 | 759 | 2988 | ... | 9 | 4380 | 1202 | 2996 | 19658 | 924 | 3914 | 10313 | False | Missense_Mutation |
| 4 | 1__CI1_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Miss... | 3139 | 1444 | 0 | 619 | 4744 | 25 | 6039 | 1757 | 12484 | ... | 14 | 6447 | 44 | 3915 | 4945 | 25 | 4066 | 20850 | False | Missense_Mutation |
5 rows × 582 columns
Mutation type¶
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['mutation']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
(y == 1).sum() / len(y)
0.6536796536796536
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier()
param_grid = {
'n_estimators': [100, 200],
'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.6847826086956522
best_score, best_params
(0.6495232766092843, {'max_depth': 3, 'n_estimators': 200})
from sklearn.metrics import accuracy_score
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.66
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Functional vs Dysfunctional¶
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['is_true']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.apply(lambda x: 1 if x == True else 0)
(y == 1).sum() / len(y)
0.32575757575757575
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier()
param_grid = {
'n_estimators': [100, 200],
'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.7282608695652174
best_score, best_params
(0.669821050437225, {'max_depth': 3, 'n_estimators': 100})
from sklearn.metrics import accuracy_score
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.68
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
TCGA Dataset¶
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df_full = df
df.head()
| Variant_Classification | ABCB9..ENSG00023457 | ABLIM1..ENSG0003983 | ACTA2..ENSG00059 | ACTB..ENSG00060 | ADORA2B..ENSG000136 | ADRB2..ENSG000154 | AEBP2..ENSG000121536 | AEN..ENSG00064782 | AGAP1..ENSG000116987 | ... | ZCCHC2..ENSG00054877 | ZDHHC14..ENSG00079683 | ZFP36L1..ENSG000677 | ZMAT3..ENSG00064393 | ZMIZ1..ENSG00057178 | ZMIZ2..ENSG00083637 | ZMYND8..ENSG00023613 | ZNF561..ENSG00093134 | is_true | mutation | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... | 376.831000 | 1358.86000 | 2471.580000 | 143602.00000 | 159.674000 | 63.136500 | 946.639000 | 626.477000 | 344.195000 | ... | 323.344000 | 75.356400 | 8558.040000 | 43.991900 | 1783.300000 | 5320.570000 | 1018.330000 | 821.181000 | True | Frame_Shift_Ins |
| 1 | A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... | 198.244448 | 5367.62179 | 2528.570328 | 77726.97678 | 19.656121 | 2.579692 | 2130.976296 | 732.991931 | 386.605718 | ... | 228.638412 | 322.247574 | 6446.509718 | 36.542642 | 3207.438557 | 3213.116903 | 1688.261865 | 1149.407697 | True | In_Frame_Del |
| 2 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 117.516000 | 1936.34000 | 14533.700000 | 185841.00000 | 95.490700 | 191.866000 | 766.578000 | 256.410000 | 239.611000 | ... | 230.672000 | 121.132000 | 12726.800000 | 74.270600 | 2496.910000 | 4005.300000 | 923.961000 | 391.689000 | True | Frame_Shift_Del |
| 3 | A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... | 60.747000 | 5667.60000 | 3560.420000 | 107645.00000 | 86.834700 | 1047.620000 | 698.413000 | 186.741000 | 262.372000 | ... | 638.609000 | 343.604000 | 8024.280000 | 78.431400 | 3746.030000 | 2692.810000 | 1168.070000 | 670.402000 | True | Frame_Shift_Del |
| 4 | A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... | 327.477000 | 1096.61000 | 3430.480000 | 64166.60000 | 51.837300 | 9.491300 | 706.010000 | 1617.540000 | 821.366000 | ... | 806.811000 | 124.118000 | 1350.690000 | 237.649000 | 1885.860000 | 2283.400000 | 1967.630000 | 480.043000 | True | Frame_Shift_Del |
5 rows × 554 columns
Functional vs Dysfunctional¶
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['is_true']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.apply(lambda x: 1 if x == True else 0)
(y == 1).sum() / len(y)
0.4490619805271907
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier()
param_grid = {
'n_estimators': [100, 200],
'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
param_grid, search_type='grid', cv=3,
verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.5629453681710214
best_score, best_params
(0.5463172397591757, {'max_depth': 3, 'n_estimators': 200})
from sklearn.metrics import accuracy_score
y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.57
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()